import numpy as np
import torch
import torchvision.transforms

def log10(t):
    """
    Calculates the base-10 log of each element in t.

    @param t: The tensor from which to calculate the base-10 log.
    @return: A tensor with the base-10 log of each element in t, on the same
             device and with the same dtype as t.
    """
    # torch.log10 runs on whatever device t lives on. The previous
    # implementation built the denominator with .cuda(), which crashed on
    # CPU-only machines and forced a host->device copy on every call.
    return torch.log10(t)


def psnr_error(gen_frames, gt_frames):
    """
    Computes the Peak Signal to Noise Ratio error between the generated images
    and the ground truth images.

    @param gen_frames: A tensor of shape [batch_size, height, width, 3] with
                       values in [-1, 1]. The frames generated by the
                       generator model.
    @param gt_frames: A tensor of the same shape as gen_frames, values in
                      [-1, 1]. The ground-truth frames for each frame in
                      gen_frames.
    @return: A scalar tensor. The mean Peak Signal to Noise Ratio over the
             batch. NOTE: a batch entry whose frames match exactly has zero
             MSE and contributes +inf to the mean.
    """
    shape = list(gen_frames.shape)
    num_pixels = shape[1] * shape[2] * shape[3]
    # Map from [-1, 1] to [0, 1] so the peak signal value is 1.
    gt_frames = (gt_frames + 1.0) / 2.0
    gen_frames = (gen_frames + 1.0) / 2.0
    square_diff = (gt_frames - gen_frames) ** 2

    # Per-sample MSE over all non-batch dims, then PSNR = 10 * log10(1 / MSE).
    # torch.log10 is used directly so this block is device-agnostic (the old
    # log10 helper hard-coded a .cuda() call and broke on CPU).
    mse = torch.sum(square_diff, [1, 2, 3]) / num_pixels
    batch_errors = 10 * torch.log10(1.0 / mse)
    return torch.mean(batch_errors)

#for [B,C,W,H]
def bgr_gray(input_tensor):
    """
    Converts a batch of BGR images, shape [B, 3, H, W] (channel order
    B, G, R), to single-channel grayscale using the BT.601 luma weights.

    @param input_tensor: A tensor of shape [B, 3, H, W].
    @return: A tensor of shape [B, 1, H, W].
    """
    # Slicing with a 0:1-style range keeps the channel dimension, so each
    # plane already has shape [B, 1, H, W] without an explicit view().
    blue = input_tensor[:, 0:1]
    green = input_tensor[:, 1:2]
    red = input_tensor[:, 2:3]
    return blue * 0.114 + green * 0.587 + red * 0.299

def diff_mask(gen_frames, gt_frames, min_value=-1, max_value=1):
    """
    Per-pixel absolute grayscale difference between two BGR frame batches.

    Both inputs are rescaled from [min_value, max_value] into [0, 1],
    converted to grayscale, and the element-wise absolute difference of the
    two grayscale batches is returned.

    @param gen_frames: Generated frames, shape [B, 3, H, W].
    @param gt_frames: Ground-truth frames, same shape as gen_frames.
    @param min_value: Lower bound of the input value range.
    @param max_value: Upper bound of the input value range.
    @return: A tensor of shape [B, 1, H, W] of absolute grayscale differences.
    """
    span = max_value - min_value
    scaled_gen = (gen_frames - min_value) / span
    scaled_gt = (gt_frames - min_value) / span

    gray_gen = bgr_gray(scaled_gen)
    gray_gt = bgr_gray(scaled_gt)

    return torch.abs(gray_gen - gray_gt)

